{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EgiF12Hf1Dhs"
   },
   "source": [
    "This notebook provides examples to go along with the [textbook](http://manipulation.csail.mit.edu/pose.html).  I recommend having both windows open, side-by-side!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eeMrMI0-1Dhu"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import matplotlib.animation as animation\n",
    "import matplotlib.pyplot as plt\n",
    "import mpld3\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display\n",
    "from pydrake.all import (\n",
    "    CsdpSolver,\n",
    "    MathematicalProgram,\n",
    "    PointCloud,\n",
    "    RigidTransform,\n",
    "    RotationMatrix,\n",
    "    Solve,\n",
    "    ge,\n",
    ")\n",
    "from scipy.spatial import KDTree\n",
    "\n",
    "from manipulation import running_as_notebook\n",
    "\n",
    "if running_as_notebook:\n",
    "    mpld3.enable_notebook()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Point cloud registration with known correspondences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MakeRandomObjectModelAndScenePoints(\n",
    "    num_model_points=20,\n",
    "    noise_std=0,\n",
    "    num_outliers=0,\n",
    "    yaw_O=None,\n",
    "    p_O=None,\n",
    "    num_viewable_points=None,\n",
    "    seed=None,\n",
    "):\n",
    "    \"\"\"Generates a random planar object model and matching scene points.\n",
    "\n",
    "    Args:\n",
    "        num_model_points: Number of model points, spaced evenly in angle\n",
    "            around the object origin.\n",
    "        noise_std: Standard deviation of the Gaussian noise added to the x,y\n",
    "            coordinates of the scene points.\n",
    "        num_outliers: Number of extra scene points, drawn uniformly from\n",
    "            [-1.5, 3.5) in x and y, appended after the true scene points.\n",
    "        yaw_O: Yaw of the object pose; drawn uniformly from [0, 0.5) if None.\n",
    "        p_O: Object position as [x, y] or [x, y, z]; x and y are drawn\n",
    "            uniformly from [0, 2) if None.  The caller's sequence is never\n",
    "            modified.\n",
    "        num_viewable_points: Only the first num_viewable_points model points\n",
    "            appear in the scene (defaults to all of them).\n",
    "        seed: Seed passed to np.random.default_rng.\n",
    "\n",
    "    Returns:\n",
    "        p_Om: 3 x num_model_points array of model points in the object frame.\n",
    "        p_s: 3 x (num_viewable_points + num_outliers) array of scene points.\n",
    "        X_O: The RigidTransform that was applied to the model to get p_s.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # Make a random set of points to define our object in the x,y plane.\n",
    "    # linspace (rather than arange with a float step) guarantees exactly\n",
    "    # num_model_points samples.\n",
    "    theta = np.linspace(0, 2.0 * np.pi, num_model_points, endpoint=False)\n",
    "    l = (\n",
    "        1.0\n",
    "        + 0.5 * np.sin(2.0 * theta)\n",
    "        + 0.4 * rng.random((1, num_model_points))\n",
    "    )\n",
    "    p_Om = np.vstack((l * np.sin(theta), l * np.cos(theta), 0 * l))\n",
    "\n",
    "    # Make a random object pose if one is not specified, and apply it to get the scene points.\n",
    "    if p_O is None:\n",
    "        p_O = [2.0 * rng.random(), 2.0 * rng.random(), 0.0]\n",
    "    else:\n",
    "        # Work on a copy so that padding z below never mutates the caller's\n",
    "        # argument (this also lets tuples and arrays be passed in).\n",
    "        p_O = list(p_O)\n",
    "    if len(p_O) == 2:\n",
    "        p_O.append(0.0)\n",
    "    if yaw_O is None:\n",
    "        yaw_O = 0.5 * rng.random()\n",
    "    X_O = RigidTransform(RotationMatrix.MakeZRotation(yaw_O), p_O)\n",
    "    if num_viewable_points is None:\n",
    "        num_viewable_points = num_model_points\n",
    "    assert num_viewable_points <= num_model_points\n",
    "    p_s = X_O.multiply(p_Om[:, :num_viewable_points])\n",
    "    # Noise is planar: only the x,y rows are perturbed.\n",
    "    p_s[:2, :] += rng.normal(scale=noise_std, size=(2, num_viewable_points))\n",
    "    if num_outliers:\n",
    "        outliers = rng.uniform(low=-1.5, high=3.5, size=(3, num_outliers))\n",
    "        outliers[2, :] = 0\n",
    "        p_s = np.hstack((p_s, outliers))\n",
    "\n",
    "    return p_Om, p_s, X_O\n",
    "\n",
    "\n",
    "def MakeRectangleModelAndScenePoints(\n",
    "    num_points_per_side=7,\n",
    "    noise_std=0,\n",
    "    num_outliers=0,\n",
    "    yaw_O=None,\n",
    "    p_O=None,\n",
    "    num_viewable_points=None,\n",
    "    seed=None,\n",
    "):\n",
    "    \"\"\"Generates points on the boundary of a 4 x 2 rectangle and matching scene points.\n",
    "\n",
    "    Args:\n",
    "        num_points_per_side: Number of model points on each of the four sides.\n",
    "        noise_std: Standard deviation of the Gaussian noise added to the x,y\n",
    "            coordinates of the scene points.\n",
    "        num_outliers: Number of extra scene points, drawn uniformly from\n",
    "            [-1.5, 3.5) in x and y, appended after the true scene points.\n",
    "        yaw_O: Yaw of the object pose; drawn uniformly from [0, 0.5) if None.\n",
    "        p_O: Object position as [x, y] or [x, y, z]; x and y are drawn\n",
    "            uniformly from [0, 2) if None.  The caller's sequence is never\n",
    "            modified.\n",
    "        num_viewable_points: Only the first num_viewable_points model points\n",
    "            appear in the scene (defaults to all of them).\n",
    "        seed: Seed passed to np.random.default_rng.\n",
    "\n",
    "    Returns:\n",
    "        p_Om: 3 x (4 * num_points_per_side) array of model points, ordered\n",
    "            top, right, bottom, left.\n",
    "        p_s: 3 x (num_viewable_points + num_outliers) array of scene points.\n",
    "        X_O: The RigidTransform that was applied to the model to get p_s.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    if p_O is None:\n",
    "        p_O = [2.0 * rng.random(), 2.0 * rng.random(), 0.0]\n",
    "    else:\n",
    "        # Work on a copy so that padding z below never mutates the caller's\n",
    "        # argument (this also lets tuples and arrays be passed in).\n",
    "        p_O = list(p_O)\n",
    "    if len(p_O) == 2:\n",
    "        p_O.append(0.0)\n",
    "    if yaw_O is None:\n",
    "        yaw_O = 0.5 * rng.random()\n",
    "    X_O = RigidTransform(RotationMatrix.MakeZRotation(yaw_O), p_O)\n",
    "    num_model_points = 4 * num_points_per_side\n",
    "    if num_viewable_points is None:\n",
    "        num_viewable_points = num_model_points\n",
    "    assert num_viewable_points <= num_model_points\n",
    "\n",
    "    # linspace (rather than arange with a float step) guarantees exactly\n",
    "    # num_points_per_side samples, so the side arrays always match the\n",
    "    # zeros row stacked underneath them.\n",
    "    x = np.linspace(-1, 1, num_points_per_side, endpoint=False)\n",
    "    half_width = 2\n",
    "    half_height = 1\n",
    "    top = np.vstack((half_width * x, half_height + 0 * x))\n",
    "    right = np.vstack((half_width + 0 * x, -half_height * x))\n",
    "    bottom = np.vstack((-half_width * x, -half_height + 0 * x))\n",
    "    left = np.vstack((-half_width + 0 * x, half_height * x))\n",
    "    p_Om = np.vstack(\n",
    "        (\n",
    "            np.hstack((top, right, bottom, left)),\n",
    "            np.zeros((1, num_model_points)),\n",
    "        )\n",
    "    )\n",
    "    p_s = X_O.multiply(p_Om[:, :num_viewable_points])\n",
    "    # Noise is planar: only the x,y rows are perturbed.\n",
    "    p_s[:2, :] += rng.normal(scale=noise_std, size=(2, num_viewable_points))\n",
    "    if num_outliers:\n",
    "        outliers = rng.uniform(low=-1.5, high=3.5, size=(3, num_outliers))\n",
    "        outliers[2, :] = 0\n",
    "        p_s = np.hstack((p_s, outliers))\n",
    "\n",
    "    return p_Om, p_s, X_O\n",
    "\n",
    "\n",
    "def PlotEstimate(\n",
    "    p_Om, p_s, Xhat_O=RigidTransform(), chat=None, X_O=None, ax=None\n",
    "):\n",
    "    \"\"\"Draws the model at the estimated pose together with the scene points.\n",
    "\n",
    "    Args:\n",
    "        p_Om: Model points in the object frame, one per column.\n",
    "        p_s: Scene points, one per column (only x and y are drawn).\n",
    "        Xhat_O: Estimated pose applied to the model before drawing.\n",
    "        chat: Optional correspondences; chat[i] is the index of the model\n",
    "            point matched to scene point i.  Drawn as dashed green lines.\n",
    "        X_O: Optional ground-truth pose.  If given, the salmon outline is the\n",
    "            model transformed by X_O; otherwise it is drawn through the first\n",
    "            Nm scene points, where Nm is the number of model points.\n",
    "        ax: Axes to draw on (a new subplot is created if None).\n",
    "\n",
    "    Returns:\n",
    "        The list of matplotlib artists that were added to ax.\n",
    "    \"\"\"\n",
    "    p_m = Xhat_O.multiply(p_Om)\n",
    "    if ax is None:\n",
    "        ax = plt.subplot()\n",
    "    Nm = p_Om.shape[1]\n",
    "    artists = ax.plot(p_m[0, :], p_m[1, :], \"bo\")\n",
    "    artists += ax.fill(p_m[0, :], p_m[1, :], \"lightblue\", alpha=0.5)\n",
    "    artists += ax.plot(p_s[0, :], p_s[1, :], \"ro\")\n",
    "    if chat is not None:\n",
    "        # One dashed segment per scene point, from its matched model point.\n",
    "        artists += ax.plot(\n",
    "            np.vstack((p_m[0, chat], p_s[0, :])),\n",
    "            np.vstack((p_m[1, chat], p_s[1, :])),\n",
    "            \"g--\",\n",
    "        )\n",
    "    # Compare against None explicitly instead of relying on the truthiness\n",
    "    # of a RigidTransform object.\n",
    "    if X_O is not None:\n",
    "        outline = X_O.multiply(p_Om)\n",
    "    else:\n",
    "        outline = p_s\n",
    "    artists += ax.fill(outline[0, :Nm], outline[1, :Nm], \"lightsalmon\")\n",
    "    ax.axis(\"equal\")\n",
    "    return artists\n",
    "\n",
    "\n",
    "def PrintResults(X_O, Xhat_O):\n",
    "    \"\"\"Prints the translation and angle-axis rotation of the true pose X_O,\n",
    "    followed by the same two quantities for the estimate Xhat_O.\n",
    "    \"\"\"\n",
    "    position, orientation = X_O.translation(), X_O.rotation().ToAngleAxis()\n",
    "    print(f\"True position: {position}\")\n",
    "    print(f\"True orientation: {orientation}\")\n",
    "\n",
    "    position, orientation = (\n",
    "        Xhat_O.translation(),\n",
    "        Xhat_O.rotation().ToAngleAxis(),\n",
    "    )\n",
    "    print(f\"Estimated position: {position}\")\n",
    "    print(f\"Estimated orientation: {orientation}\")\n",
    "\n",
    "\n",
    "def PoseEstimationGivenCorrespondences(p_Om, p_s, chat):\n",
    "    \"\"\"Returns optimal X_O given the correspondences.\n",
    "\n",
    "    p_Om holds the model points and p_s the scene points (one point per\n",
    "    column); chat[i] is the index of the model point matched to scene point i.\n",
    "    \"\"\"\n",
    "    # Reorder the model points by correspondence; rows are points from here on\n",
    "    # so that subtracting a centroid broadcasts over all points.\n",
    "    model = p_Om[:, chat].T\n",
    "    scene = p_s.T\n",
    "\n",
    "    # Centroids of the two point sets\n",
    "    model_centroid = model.mean(axis=0)\n",
    "    scene_centroid = scene.mean(axis=0)\n",
    "\n",
    "    # Data matrix built from the centered points\n",
    "    W = (scene - scene_centroid).T @ (model - model_centroid)\n",
    "\n",
    "    # Rotation from the SVD of the data matrix\n",
    "    U, _, Vt = np.linalg.svd(W)\n",
    "    R = U @ Vt\n",
    "    if np.linalg.det(R) < 0:\n",
    "        # U @ Vt came out as a reflection; flip the last row of Vt and\n",
    "        # recompute to get a proper rotation.\n",
    "        print(\"fixing improper rotation\")\n",
    "        Vt[-1, :] *= -1\n",
    "        R = U @ Vt\n",
    "\n",
    "    # Translation that carries the rotated model centroid onto the scene centroid\n",
    "    p = scene_centroid - R @ model_centroid\n",
    "\n",
    "    return RigidTransform(RotationMatrix(R), p)\n",
    "\n",
    "\n",
    "# Generate a random object and its scene points (no noise or outliers by default).\n",
    "p_Om, p_s, X_O = MakeRandomObjectModelAndScenePoints(num_model_points=20)\n",
    "# p_Om, p_s, X_O = MakeRectangleModelAndScenePoints()\n",
    "Xhat = RigidTransform()\n",
    "c = range(p_Om.shape[1])  # perfect, known correspondences\n",
    "fig, ax = plt.subplots(1, 2)\n",
    "# Left panel: the model drawn at the identity pose, before registration.\n",
    "PlotEstimate(p_Om, p_s, Xhat, c, ax=ax[0])\n",
    "Xhat = PoseEstimationGivenCorrespondences(p_Om, p_s, c)\n",
    "# Right panel: copy the left panel's axis limits, then draw the model at the estimated pose.\n",
    "ax[1].set_xlim(ax[0].get_xlim())\n",
    "ax[1].set_ylim(ax[0].get_ylim())\n",
    "PlotEstimate(p_Om, p_s, Xhat, c, ax=ax[1])\n",
    "ax[0].set_title(\"Original Data\")\n",
    "ax[1].set_title(\"After Registration\")\n",
    "PrintResults(X_O, Xhat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AHfxMwrvb1mz"
   },
   "source": [
    "# Iterative Closest Point (ICP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "N2cYjTpub1m0",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def FindClosestPoints(point_cloud_A, point_cloud_B):\n",
    "    \"\"\"\n",
    "    Finds the nearest (Euclidean) neighbor in point_cloud_B for each\n",
    "    point in point_cloud_A.\n",
    "    @param point_cloud_A A 3xN numpy array of points.\n",
    "    @param point_cloud_B A 3xM numpy array of points (M need not equal N).\n",
    "    @return indices An (N, ) numpy array of the indices in point_cloud_B of each\n",
    "        point_cloud_A point's nearest neighbor.\n",
    "    \"\"\"\n",
    "    kdtree = KDTree(point_cloud_B.T)\n",
    "    # Query all N points in one batched call instead of one Python-level\n",
    "    # query per point; only the indices are needed, not the distances.\n",
    "    _, indices = kdtree.query(point_cloud_A.T, k=1)\n",
    "\n",
    "    return np.asarray(indices, dtype=int)\n",
    "\n",
    "\n",
    "def IterativeClosestPoint(p_Om, p_s, X_O=None, animate=True):\n",
    "    \"\"\"Estimates the object pose by alternating closest-point matching and pose fitting.\n",
    "\n",
    "    Starting from the identity pose, each iteration matches every scene point\n",
    "    to its nearest transformed model point and then re-solves for the pose\n",
    "    given those correspondences.  Iteration stops as soon as the\n",
    "    correspondences stop changing.\n",
    "\n",
    "    Args:\n",
    "        p_Om: Model points in the object frame, one per column.\n",
    "        p_s: Scene points, one per column.\n",
    "        X_O: Optional ground-truth pose; if given, it is drawn in the\n",
    "            animation and printed next to the estimate.\n",
    "        animate: If True, records a frame after every matching and fitting\n",
    "            step and displays the result as a JavaScript animation.\n",
    "\n",
    "    Returns:\n",
    "        Xhat: The estimated pose.\n",
    "        chat: Array holding, for each scene point, the index of its matched\n",
    "            model point.\n",
    "    \"\"\"\n",
    "    Xhat = RigidTransform()\n",
    "    num_scene_points = p_s.shape[1]\n",
    "    # Start from a value that FindClosestPoints will never return, so the\n",
    "    # first iteration can never be mistaken for convergence.\n",
    "    chat_previous = np.full(num_scene_points, -1)\n",
    "\n",
    "    if animate:\n",
    "        fig, ax = plt.subplots()\n",
    "        frames = []\n",
    "        frames.append(\n",
    "            PlotEstimate(\n",
    "                p_Om=p_Om, p_s=p_s, Xhat_O=Xhat, chat=None, X_O=X_O, ax=ax\n",
    "            )\n",
    "        )\n",
    "\n",
    "    while True:\n",
    "        chat = FindClosestPoints(p_s, Xhat.multiply(p_Om))\n",
    "        if np.array_equal(chat, chat_previous):\n",
    "            # Then I've converged.\n",
    "            break\n",
    "        chat_previous = chat\n",
    "        if animate:\n",
    "            # Frame showing the new correspondences at the current pose.\n",
    "            frames.append(\n",
    "                PlotEstimate(\n",
    "                    p_Om=p_Om, p_s=p_s, Xhat_O=Xhat, chat=chat, X_O=X_O, ax=ax\n",
    "                )\n",
    "            )\n",
    "        Xhat = PoseEstimationGivenCorrespondences(p_Om, p_s, chat)\n",
    "        if animate:\n",
    "            # Frame showing the updated pose (no correspondence lines).\n",
    "            frames.append(\n",
    "                PlotEstimate(\n",
    "                    p_Om=p_Om, p_s=p_s, Xhat_O=Xhat, chat=None, X_O=X_O, ax=ax\n",
    "                )\n",
    "            )\n",
    "\n",
    "    if animate:\n",
    "        ani = animation.ArtistAnimation(\n",
    "            fig, frames, interval=400, repeat=False\n",
    "        )\n",
    "\n",
    "        display(HTML(ani.to_jshtml()))\n",
    "        # Close this function's own figure so only the animation is shown.\n",
    "        plt.close(fig)\n",
    "\n",
    "    # Compare against None explicitly instead of relying on the truthiness\n",
    "    # of a RigidTransform object.\n",
    "    if X_O is not None:\n",
    "        PrintResults(X_O, Xhat)\n",
    "\n",
    "    return Xhat, chat\n",
    "\n",
    "\n",
    "# Run ICP on a fresh random object; passing X_O also prints the true and estimated poses.\n",
    "p_Om, p_s, X_O = MakeRandomObjectModelAndScenePoints(num_model_points=20)\n",
    "IterativeClosestPoint(p_Om, p_s, X_O);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WLMLxAJJb1m2"
   },
   "source": [
    "Try increasing the yaw of the object in the example above (for instance, by passing a larger `yaw_O` to `MakeRandomObjectModelAndScenePoints`).  At some point, the performance can get pretty poor!\n",
    "\n",
    "# ICP with messy point clouds\n",
    "\n",
    "Try changing the amount of noise, the number of outliers, and/or the partial views.  There are not particularly good theorems here, but I hope that a little bit of play will get you a lot of intuition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rectangle scene at a fixed pose; uncomment the options below to add noise,\n",
    "# outliers, or a partial view before running ICP.\n",
    "p_Om, p_s, X_O = MakeRectangleModelAndScenePoints(\n",
    "    #    noise_std=0.2, # adds noise to each scene point (default is 0.0)\n",
    "    #    num_outliers=3, # adds random points from a uniform distribution\n",
    "    #    num_viewable_points=9, # only this number of model points appear in the scene points\n",
    "    yaw_O=0.2,  # object orientation (comment it out for random rotations)\n",
    "    p_O=[1, 2],  # object position (comment it out for random positions)\n",
    ")\n",
    "IterativeClosestPoint(p_Om, p_s, X_O);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Is least-squares the right cost function?\n",
    "\n",
    "Here is a particular setup that is interesting.  The configuration I've given you below results in ICP getting stuck in a local minimum.  You will find that the system converges to this local minimum from a wide variety of initial conditions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partial view: only the first 9 model points appear in the scene.\n",
    "p_Om, p_s, X_O = MakeRectangleModelAndScenePoints(\n",
    "    num_viewable_points=9,\n",
    "    yaw_O=0.2,\n",
    "    p_O=[1, 2],\n",
    ")\n",
    "IterativeClosestPoint(p_Om, p_s, X_O);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-penetration constraints with nonlinear optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ConstrainedKnownCorrespondenceNonlinearOptimization(p_Om, p_s, chat):\n",
    "    \"\"\"This version adds a non-penetration constraint (x,y >= 0).\n",
    "\n",
    "    Solves for a planar pose (translation p and yaw theta) that minimizes the\n",
    "    summed squared distance between matched model and scene points, subject\n",
    "    to every transformed model point having non-negative x and y.\n",
    "    \"\"\"\n",
    "    # Work in the plane: keep only x,y and reorder the model by correspondence.\n",
    "    model_xy = p_Om[:2, chat]\n",
    "    scene_xy = p_s[:2, :]\n",
    "    num_scene_points = scene_xy.shape[1]\n",
    "\n",
    "    prog = MathematicalProgram()\n",
    "    p = prog.NewContinuousVariables(2, \"p\")\n",
    "    theta = prog.NewContinuousVariables(1, \"theta\")\n",
    "    # Every cost and constraint below takes [p_x, p_y, theta] as one vector.\n",
    "    decision_vars = np.concatenate([p[:], theta])\n",
    "\n",
    "    def position_model_in_world(x, i):\n",
    "        \"\"\"Transformed position of matched model point i, for x = [p_x, p_y, theta].\"\"\"\n",
    "        translation, yaw = x[:2], x[2]\n",
    "        c, s = np.cos(yaw), np.sin(yaw)\n",
    "        R = np.array([[c, -s], [s, c]])\n",
    "        return translation + R @ model_xy[:, i]\n",
    "\n",
    "    def squared_distance(x, i):\n",
    "        \"\"\"Squared distance from transformed model point i to scene point i.\"\"\"\n",
    "        residual = position_model_in_world(x, i) - scene_xy[:, i]\n",
    "        return residual.dot(residual)\n",
    "\n",
    "    for i in range(num_scene_points):\n",
    "        # forall i, |p + R*p_Omi - p_si|^2\n",
    "        prog.AddCost(partial(squared_distance, i=i), decision_vars)\n",
    "        # forall i, p + R*p_Omi >= 0.\n",
    "        prog.AddConstraint(\n",
    "            partial(position_model_in_world, i=i),\n",
    "            vars=decision_vars,\n",
    "            lb=[0, 0],\n",
    "            ub=[np.inf, np.inf],\n",
    "        )\n",
    "\n",
    "    result = Solve(prog)\n",
    "\n",
    "    # Lift the planar solution back to a 3D pose (rotation about z, z = 0).\n",
    "    yaw_sol = result.GetSolution(theta[0])\n",
    "    cos_sol, sin_sol = np.cos(yaw_sol), np.sin(yaw_sol)\n",
    "    R_sol = np.array(\n",
    "        [\n",
    "            [cos_sol, -sin_sol, 0],\n",
    "            [sin_sol, cos_sol, 0],\n",
    "            [0, 0, 1],\n",
    "        ]\n",
    "    )\n",
    "    p_sol = np.append(result.GetSolution(p), 0.0)\n",
    "\n",
    "    return RigidTransform(RotationMatrix(R_sol), p_sol)\n",
    "\n",
    "\n",
    "# Rectangle scene at a fixed pose, solved with perfect correspondences.\n",
    "p_Om, p_s, X_O = MakeRectangleModelAndScenePoints(\n",
    "    yaw_O=0.2,\n",
    "    p_O=[1.5, 1.2],\n",
    ")\n",
    "c = range(p_Om.shape[1])  # perfect, known correspondences\n",
    "Xhat_O = ConstrainedKnownCorrespondenceNonlinearOptimization(p_Om, p_s, c)\n",
    "PlotEstimate(p_Om=p_Om, p_s=p_s, Xhat_O=Xhat_O, chat=c, X_O=X_O)\n",
    "PrintResults(X_O, Xhat_O)\n",
    "# Draw the x = 0 and y = 0 boundaries of the non-penetration constraint in green.\n",
    "plt.gca().plot([0, 0], [0, 2.5], \"g-\", linewidth=3)\n",
    "plt.gca().plot([0, 4], [0, 0], \"g-\", linewidth=3);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-penetration (half-plane) constraints with convex optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ConstrainedKnownCorrespondenceConvexRelaxation(p_Om, p_s, chat):\n",
    "    \"\"\"This version adds a non-penetration constraint (x,y >= 0).\n",
    "\n",
    "    The planar rotation is parameterized as [[a, -b], [b, a]], and a\n",
    "    Lorentz cone constraint is placed on [1, a, b].  The objective is the sum\n",
    "    of per-point slack variables that upper-bound each point's distance error.\n",
    "    \"\"\"\n",
    "    # Work in the plane: keep only x,y and reorder the model by correspondence.\n",
    "    model_xy = p_Om[:2, chat]\n",
    "    scene_xy = p_s[:2, :]\n",
    "    num_scene_points = scene_xy.shape[1]\n",
    "\n",
    "    prog = MathematicalProgram()\n",
    "    [a, b] = prog.NewContinuousVariables(2)\n",
    "    # We use the slack variable as an upper bound on the cost of each point to make the objective linear.\n",
    "    slack = prog.NewContinuousVariables(num_scene_points)\n",
    "    p = prog.NewContinuousVariables(2)\n",
    "    prog.AddBoundingBoxConstraint(0, 1, [a, b])  # This makes Csdp happier\n",
    "    R = np.array([[a, -b], [b, a]])\n",
    "    prog.AddLorentzConeConstraint([1.0, a, b])\n",
    "\n",
    "    # Note: Could do this more efficiently, exploiting trace.  But I'm keeping it simpler here.\n",
    "    prog.AddCost(np.sum(slack))\n",
    "    for i in range(num_scene_points):\n",
    "        model_in_world = p + R @ model_xy[:, i]\n",
    "        error = model_in_world - scene_xy[:, i]\n",
    "        # forall i, slack[i]^2 >= |error|^2\n",
    "        prog.AddLorentzConeConstraint([slack[i], error[0], error[1]])\n",
    "        # forall i, p + R*mi >= 0.\n",
    "        prog.AddConstraint(ge(model_in_world, [0, 0]))\n",
    "\n",
    "    result = CsdpSolver().Solve(prog)\n",
    "\n",
    "    # Lift the planar solution back to a 3D pose (rotation about z, z = 0).\n",
    "    a_sol, b_sol = result.GetSolution([a, b])\n",
    "    R_sol = np.array([[a_sol, -b_sol, 0], [b_sol, a_sol, 0], [0, 0, 1]])\n",
    "    p_sol = np.append(result.GetSolution(p), 0.0)\n",
    "\n",
    "    return RigidTransform(RotationMatrix(R_sol), p_sol)\n",
    "\n",
    "\n",
    "# Rectangle scene at a fixed pose, solved with perfect correspondences.\n",
    "p_Om, p_s, X_O = MakeRectangleModelAndScenePoints(\n",
    "    yaw_O=0.2,\n",
    "    p_O=[1.5, 1.2],\n",
    ")\n",
    "c = range(p_Om.shape[1])  # perfect, known correspondences\n",
    "Xhat_O = ConstrainedKnownCorrespondenceConvexRelaxation(p_Om, p_s, c)\n",
    "PlotEstimate(p_Om=p_Om, p_s=p_s, Xhat_O=Xhat_O, chat=c, X_O=X_O)\n",
    "PrintResults(X_O, Xhat_O)\n",
    "# Draw the x = 0 and y = 0 boundaries of the non-penetration constraint in green.\n",
    "plt.gca().plot([0, 0], [0, 2.5], \"g-\", linewidth=3)\n",
    "plt.gca().plot([0, 4], [0, 0], \"g-\", linewidth=3);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Robotic Manipulation - Geometric Pose Estimation.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 (main, Oct 13 2022, 09:48:40) [Clang 14.0.0 (clang-1400.0.29.102)]"
  },
  "metadata": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
